{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ccad9292",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Set process environment variables up front, before the ML libraries are imported in the next cell.\n",
    "# An empty CUDA_VISIBLE_DEVICES is the usual way to hide all GPUs so the work runs on CPU.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "# TensorFlow-specific flag; no TensorFlow import is visible in this notebook.\n",
    "# NOTE(review): looks like a leftover from another notebook - confirm before removing.\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "29b29c4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import transformers\n",
    "from transformers import BertTokenizer, AutoModelForMaskedLM, BertConfig, BertForMaskedLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fdfad0b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained multilingual BERT checkpoint together with its masked-LM head.\n",
    "# This full-size model is the source of the weights copied into the smaller model further down.\n",
    "model = AutoModelForMaskedLM.from_pretrained('bert-base-multilingual-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "38b560db",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['T_destination',\n",
       " '__annotations__',\n",
       " '__call__',\n",
       " '__class__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattr__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " '_apply',\n",
       " '_auto_class',\n",
       " '_backward_compatibility_gradient_checkpointing',\n",
       " '_backward_hooks',\n",
       " '_buffers',\n",
       " '_call_impl',\n",
       " '_can_retrieve_inputs_from_name',\n",
       " '_convert_head_mask_to_5d',\n",
       " '_create_or_get_repo',\n",
       " '_expand_inputs_for_generation',\n",
       " '_forward_hooks',\n",
       " '_forward_pre_hooks',\n",
       " '_from_config',\n",
       " '_get_backward_hooks',\n",
       " '_get_decoder_start_token_id',\n",
       " '_get_logits_processor',\n",
       " '_get_logits_warper',\n",
       " '_get_name',\n",
       " '_get_repo_url_from_name',\n",
       " '_get_resized_embeddings',\n",
       " '_get_resized_lm_head',\n",
       " '_get_stopping_criteria',\n",
       " '_hook_rss_memory_post_forward',\n",
       " '_hook_rss_memory_pre_forward',\n",
       " '_init_weights',\n",
       " '_is_full_backward_hook',\n",
       " '_keys_to_ignore_on_load_missing',\n",
       " '_keys_to_ignore_on_load_unexpected',\n",
       " '_keys_to_ignore_on_save',\n",
       " '_load_from_state_dict',\n",
       " '_load_pretrained_model',\n",
       " '_load_pretrained_model_low_mem',\n",
       " '_load_state_dict_pre_hooks',\n",
       " '_maybe_warn_non_full_backward_hook',\n",
       " '_merge_criteria_processor_list',\n",
       " '_modules',\n",
       " '_named_members',\n",
       " '_non_persistent_buffers_set',\n",
       " '_parameters',\n",
       " '_prepare_attention_mask_for_generation',\n",
       " '_prepare_decoder_input_ids_for_generation',\n",
       " '_prepare_encoder_decoder_kwargs_for_generation',\n",
       " '_prepare_input_ids_for_generation',\n",
       " '_prepare_model_inputs',\n",
       " '_push_to_hub',\n",
       " '_register_load_state_dict_pre_hook',\n",
       " '_register_state_dict_hook',\n",
       " '_reorder_cache',\n",
       " '_replicate_for_data_parallel',\n",
       " '_resize_token_embeddings',\n",
       " '_save_to_state_dict',\n",
       " '_set_default_torch_dtype',\n",
       " '_set_gradient_checkpointing',\n",
       " '_slow_forward',\n",
       " '_state_dict_hooks',\n",
       " '_tie_encoder_decoder_weights',\n",
       " '_tie_or_clone_weights',\n",
       " '_update_model_kwargs_for_generation',\n",
       " '_version',\n",
       " 'add_memory_hooks',\n",
       " 'add_module',\n",
       " 'adjust_logits_during_generation',\n",
       " 'apply',\n",
       " 'base_model',\n",
       " 'base_model_prefix',\n",
       " 'beam_sample',\n",
       " 'beam_search',\n",
       " 'bert',\n",
       " 'bfloat16',\n",
       " 'buffers',\n",
       " 'children',\n",
       " 'cls',\n",
       " 'compute_transition_beam_scores',\n",
       " 'config',\n",
       " 'config_class',\n",
       " 'constrained_beam_search',\n",
       " 'cpu',\n",
       " 'create_extended_attention_mask_for_decoder',\n",
       " 'cuda',\n",
       " 'device',\n",
       " 'double',\n",
       " 'dtype',\n",
       " 'dummy_inputs',\n",
       " 'dump_patches',\n",
       " 'estimate_tokens',\n",
       " 'eval',\n",
       " 'extra_repr',\n",
       " 'float',\n",
       " 'floating_point_ops',\n",
       " 'forward',\n",
       " 'framework',\n",
       " 'from_pretrained',\n",
       " 'generate',\n",
       " 'get_buffer',\n",
       " 'get_extended_attention_mask',\n",
       " 'get_extra_state',\n",
       " 'get_head_mask',\n",
       " 'get_input_embeddings',\n",
       " 'get_output_embeddings',\n",
       " 'get_parameter',\n",
       " 'get_position_embeddings',\n",
       " 'get_submodule',\n",
       " 'gradient_checkpointing_disable',\n",
       " 'gradient_checkpointing_enable',\n",
       " 'greedy_search',\n",
       " 'group_beam_search',\n",
       " 'half',\n",
       " 'init_weights',\n",
       " 'invert_attention_mask',\n",
       " 'is_gradient_checkpointing',\n",
       " 'is_parallelizable',\n",
       " 'load_state_dict',\n",
       " 'load_tf_weights',\n",
       " 'main_input_name',\n",
       " 'modules',\n",
       " 'name_or_path',\n",
       " 'named_buffers',\n",
       " 'named_children',\n",
       " 'named_modules',\n",
       " 'named_parameters',\n",
       " 'num_parameters',\n",
       " 'parameters',\n",
       " 'post_init',\n",
       " 'prepare_inputs_for_generation',\n",
       " 'prune_heads',\n",
       " 'push_to_hub',\n",
       " 'register_backward_hook',\n",
       " 'register_buffer',\n",
       " 'register_for_auto_class',\n",
       " 'register_forward_hook',\n",
       " 'register_forward_pre_hook',\n",
       " 'register_full_backward_hook',\n",
       " 'register_parameter',\n",
       " 'requires_grad_',\n",
       " 'reset_memory_hooks_state',\n",
       " 'resize_position_embeddings',\n",
       " 'resize_token_embeddings',\n",
       " 'retrieve_modules_from_names',\n",
       " 'sample',\n",
       " 'save_pretrained',\n",
       " 'set_extra_state',\n",
       " 'set_input_embeddings',\n",
       " 'set_output_embeddings',\n",
       " 'share_memory',\n",
       " 'state_dict',\n",
       " 'supports_gradient_checkpointing',\n",
       " 'tie_weights',\n",
       " 'to',\n",
       " 'to_empty',\n",
       " 'train',\n",
       " 'training',\n",
       " 'type',\n",
       " 'xpu',\n",
       " 'zero_grad']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exploratory: list the attributes available on the loaded model (e.g. `bert`, `cls`, `config`).\n",
    "dir(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d939143d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer for the same checkpoint; it is saved next to the reduced model at the end of the notebook.\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "880d79a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the configuration of the original checkpoint and override only the encoder depth.\n",
    "config_8layers = BertConfig.from_pretrained(\n",
    "    'bert-base-multilingual-cased',\n",
    "    num_hidden_layers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4adbb752",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the reduced model from the config alone: no checkpoint is loaded here, so its weights\n",
    "# still have to be filled in from `model` by the cells below.\n",
    "half = BertForMaskedLM(config_8layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "150eddb0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
       "  (embeddings): BertEmbeddings(\n",
       "    (word_embeddings): Embedding(119547, 768, padding_idx=0)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
       "    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0.1, inplace=False)\n",
       "  )\n",
       "  (encoder): BertEncoder(\n",
       "    (layer): ModuleList(\n",
       "      (0): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (1): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (2): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (3): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (4): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (5): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (6): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (7): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (8): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (9): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (10): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "      (11): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (output): BertSelfOutput(\n",
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "        (intermediate): BertIntermediate(\n",
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (intermediate_act_fn): GELUActivation()\n",
       "        )\n",
       "        (output): BertOutput(\n",
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the encoder structure of the full model (the printed output lists layers 0-11).\n",
    "model.bert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3416d6f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "def copy_layers(src_layers, dest_layers, layers_to_copy):\n",
    "    layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy])\n",
    "    assert len(dest_layers) == len(layers_to_copy), f\"{len(dest_layers)} != {len(layers_to_copy)}\"\n",
    "    dest_layers.load_state_dict(layers_to_copy.state_dict())\n",
    "\n",
    "# Keep the first N encoder layers of the full model, where N is taken from the reduced\n",
    "# config so the index list cannot drift from the size of `half` (here: [0, 1, ..., 7]).\n",
    "layers_to_copy = list(range(config_8layers.num_hidden_layers))\n",
    "\n",
    "copy_layers(model.bert.encoder.layer, half.bert.encoder.layer, layers_to_copy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a76e65ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The embedding block has the same shape in both models (same config apart from the\n",
    "# number of encoder layers), so it is copied over unchanged.\n",
    "half.bert.embeddings.load_state_dict(model.bert.embeddings.state_dict())\n",
    "# Also copy the masked-LM head (`cls`). Without this it keeps the weights it was given when\n",
    "# `half` was built from the config, and the saved checkpoint would not carry the pretrained head.\n",
    "half.cls.load_state_dict(model.cls.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9064ff6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the reduced model to the local directory \"bert-multilanguage-8layers\".\n",
    "half.save_pretrained(\"bert-multilanguage-8layers\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9b04b986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('bert-multilanguage-8layers/tokenizer_config.json',\n",
       " 'bert-multilanguage-8layers/special_tokens_map.json',\n",
       " 'bert-multilanguage-8layers/vocab.txt',\n",
       " 'bert-multilanguage-8layers/added_tokens.json')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save the tokenizer files (vocab.txt, tokenizer_config.json, ...) into the same directory as the model.\n",
    "tokenizer.save_pretrained(\"bert-multilanguage-8layers\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
